// Copyright (c) 2023 Huawei Device Co., Ltd.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! `Resolver` trait and `DefaultDnsResolver` implementation.

use std::collections::HashMap;
use std::future::Future;
use std::io;
use std::io::Error;
use std::net::{SocketAddr, ToSocketAddrs};
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use std::time::{Duration, Instant};
use std::vec::IntoIter;

use crate::runtime::JoinHandle;

/// Time-to-live applied to cached lookups by `DefaultDnsResolver::default()`.
const DEFAULT_TTL: Duration = Duration::from_secs(60);
/// Cache size above which expired entries are purged when a new resolution
/// is requested.
const MAX_ENTRIES_LEN: usize = 30_000;

/// `SocketAddr` resolved by `Resolver`.
///
/// A boxed, `Send + Sync` iterator over the resolved socket addresses.
pub type Addrs = Box<dyn Iterator<Item = SocketAddr> + Sync + Send>;
/// Possible errors that this resolver may generate when attempting to
/// resolve.
///
/// Any boxed error type that is `Send + Sync`.
pub type StdError = Box<dyn std::error::Error + Send + Sync>;
/// Futures generated by this resolver when attempting to resolve an address.
///
/// The future is pinned and boxed, is `Send + Sync`, may borrow data for
/// `'a`, and yields either the resolved [`Addrs`] or a [`StdError`].
pub type SocketFuture<'a> =
    Pin<Box<dyn Future<Output = Result<Addrs, StdError>> + Sync + Send + 'a>>;

/// `Resolver` trait used by `async_impl::connector::HttpConnector`. `Resolver`
/// provides asynchronous DNS resolution interfaces.
///
/// Implementors must be `Send + Sync + 'static`.
pub trait Resolver: Send + Sync + 'static {
    /// Resolves `authority` into a future yielding its `SocketAddr`s.
    ///
    /// By lifetime elision, the returned future's lifetime is tied to the
    /// `&self` borrow. It may borrow `self` but not `authority`, so
    /// implementations must copy whatever they need from `authority`.
    /// NOTE(review): `DefaultDnsResolver` passes `authority` to
    /// `ToSocketAddrs`, which expects `host:port`; confirm that is the
    /// contract for all callers.
    fn resolve(&self, authority: &str) -> SocketFuture;
}

/// `SocketAddr` resolved by `DefaultDnsResolver`.
///
/// A thin iterator wrapper around the owned list of addresses from a lookup.
pub struct ResolvedAddrs {
    iter: IntoIter<SocketAddr>,
}

impl ResolvedAddrs {
    /// Wraps an owning iterator over resolved addresses.
    pub(super) fn new(iter: IntoIter<SocketAddr>) -> Self {
        Self { iter }
    }

    /// Splits the remaining addresses by address family.
    ///
    /// The family of the first remaining address (the first IP in the DNS
    /// record) is treated as the preferred one. Returns
    /// `(preferred, others)`; each vector keeps the original relative order.
    /// With no addresses left, both vectors are empty.
    pub(super) fn split_preferred_addrs(self) -> (Vec<SocketAddr>, Vec<SocketAddr>) {
        let prefer_v6 = match self.iter.as_slice().first() {
            Some(head) => head.is_ipv6(),
            None => false,
        };
        self.iter
            .partition(|candidate| candidate.is_ipv6() == prefer_v6)
    }
}

impl Iterator for ResolvedAddrs {
    type Item = SocketAddr;

    /// Yields the next resolved address, in lookup order.
    fn next(&mut self) -> Option<SocketAddr> {
        self.iter.next()
    }
}

/// Futures generated by `DefaultDnsResolver`.
///
/// Wraps the `JoinHandle` of a task producing `ResolvedAddrs` and adapts its
/// output to the `Resolver` result types. NOTE(review): the `resolve`
/// implementation in this module does not create this future itself; confirm
/// its producer.
pub struct DefaultDnsFuture {
    inner: JoinHandle<Result<ResolvedAddrs, Error>>,
}

impl Future for DefaultDnsFuture {
    type Output = Result<Addrs, StdError>;

    /// Polls the underlying task and converts its outcome.
    ///
    /// A lookup error is boxed as-is. A join failure (the task did not
    /// complete) is reported as an `io::ErrorKind::Interrupted` error.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let joined = match Pin::new(&mut self.inner).poll(cx) {
            Poll::Ready(joined) => joined,
            Poll::Pending => return Poll::Pending,
        };
        let output = match joined {
            Ok(Ok(resolved)) => Ok(Box::new(resolved) as Addrs),
            Ok(Err(lookup_err)) => Err(Box::new(lookup_err) as StdError),
            Err(join_err) => {
                let interrupted = Error::new(io::ErrorKind::Interrupted, join_err);
                Err(Box::new(interrupted) as StdError)
            }
        };
        Poll::Ready(output)
    }
}

/// Default DNS resolver used by the `Client`.
///
/// Lookups go through the standard library's `ToSocketAddrs`. Successful
/// results are cached per authority for a configurable time-to-live.
pub struct DefaultDnsResolver {
    /// Owns the authority -> addresses cache.
    manager: DnsManager,
    /// Performs the actual lookup on a cache miss.
    connector: DnsConnector,
    /// How long a freshly resolved entry may be served from the cache.
    ttl: Duration,
}

impl Default for DefaultDnsResolver {
    /// Same as `DefaultDnsResolver::new` with a TTL of `DEFAULT_TTL`
    /// (60 seconds).
    fn default() -> Self {
        Self::new(DEFAULT_TTL)
    }
}

impl DefaultDnsResolver {
    /// Creates a new `DefaultDnsResolver` whose cached entries live for `ttl`.
    ///
    /// # Examples
    ///
    /// ```
    /// use std::time::Duration;
    ///
    /// use ylong_http_client::async_impl::DefaultDnsResolver;
    ///
    /// let res = DefaultDnsResolver::new(Duration::from_secs(1));
    /// ```
    pub fn new(ttl: Duration) -> Self {
        Self {
            manager: DnsManager::new(),
            connector: DnsConnector {},
            ttl,
        }
    }
}

/// Cache of resolved addresses keyed by authority string.
#[derive(Default)]
struct DnsManager {
    /// Authority (as passed to `Resolver::resolve`) -> cached lookup result.
    map: Mutex<HashMap<String, DnsResult>>,
}

impl DnsManager {
    /// Returns a manager with an empty cache.
    fn new() -> Self {
        Self {
            map: Mutex::new(HashMap::new()),
        }
    }

    /// Drops expired entries from the cache.
    ///
    /// The scan only runs once the cache holds more than `MAX_ENTRIES_LEN`
    /// entries. Below that threshold this only takes and releases the lock.
    /// Entries that are still valid are kept even when the cache is over the
    /// threshold.
    ///
    /// # Panics
    ///
    /// Panics if the cache mutex or an entry's mutex is poisoned.
    fn clean_expired_entries(&self) {
        let mut cache = self.map.lock().unwrap();
        if cache.len() <= MAX_ENTRIES_LEN {
            return;
        }
        cache.retain(|_authority, entry| entry.inner.lock().unwrap().is_valid());
    }
}

/// A cached DNS lookup, stored behind `Arc<Mutex<..>>`.
struct DnsResult {
    inner: Arc<Mutex<DnsResultInner>>,
}

impl DnsResult {
    /// Builds a cache entry holding `addr` until `expiration_time`.
    fn new(addr: Vec<SocketAddr>, expiration_time: Instant) -> Self {
        let entry = DnsResultInner {
            addr,
            expiration_time,
        };
        Self {
            inner: Arc::new(Mutex::new(entry)),
        }
    }
}

/// The addresses resolved for one authority, together with their deadline.
#[derive(Clone)]
struct DnsResultInner {
    /// Addresses in the order the lookup returned them.
    addr: Vec<SocketAddr>,
    /// Instant from which this entry is no longer considered valid.
    expiration_time: Instant,
}

impl DnsResultInner {
    /// Returns `true` while the deadline still lies strictly in the future.
    fn is_valid(&self) -> bool {
        Instant::now() < self.expiration_time
    }
}

impl Default for DnsResultInner {
    /// An empty address list that expires 60 seconds from now.
    fn default() -> Self {
        Self {
            addr: Vec::new(),
            expiration_time: Instant::now() + Duration::from_secs(60),
        }
    }
}

/// Performs DNS lookups through the standard library's `ToSocketAddrs`.
struct DnsConnector {}

impl DnsConnector {
    /// Resolves `authority` (a `host:port` string or a literal socket
    /// address) into every socket address the system resolver returns.
    ///
    /// Note that `ToSocketAddrs` may block the calling thread while the
    /// lookup runs.
    ///
    /// # Errors
    ///
    /// Returns the `io::Error` produced by `to_socket_addrs` unchanged. Its
    /// `ErrorKind` is preserved (e.g. `InvalidInput` for a missing or
    /// malformed port) instead of being flattened to `ErrorKind::Other`.
    fn get_socket_addrs(&self, authority: &str) -> Result<Vec<SocketAddr>, io::Error> {
        authority
            .to_socket_addrs()
            .map(|addrs| addrs.collect())
    }
}

impl Resolver for DefaultDnsResolver {
    /// Resolves `authority`, answering from the cache while the stored entry
    /// is still valid. Otherwise it queries `DnsConnector` and caches a
    /// successful answer for `self.ttl`. Failed lookups are not cached.
    ///
    /// The cache mutex is held only for the cache lookup and for the
    /// insertion, never while the DNS query runs. A slow query therefore does
    /// not block resolutions of other authorities. Two concurrent misses for
    /// the same authority may both query DNS; the later insertion wins.
    ///
    /// # Panics
    ///
    /// Panics if the cache mutex or an entry's mutex is poisoned.
    fn resolve(&self, authority: &str) -> SocketFuture {
        let authority = authority.to_string();
        self.manager.clean_expired_entries();
        Box::pin(async move {
            // Look for a live cache entry. The guards are dropped at the end
            // of this block, before any DNS query is made.
            let cached = {
                let map_lock = self.manager.map.lock().unwrap();
                match map_lock.get(&authority) {
                    Some(result) => {
                        let inner = result.inner.lock().unwrap();
                        if inner.is_valid() {
                            Some(inner.addr.clone())
                        } else {
                            None
                        }
                    }
                    None => None,
                }
            };
            if let Some(addrs) = cached {
                return Ok(Box::new(addrs.into_iter()) as Addrs);
            }

            // NOTE(review): `get_socket_addrs` uses `ToSocketAddrs`, which can
            // block the executor thread inside this future; consider moving
            // the lookup onto a blocking task.
            match self.connector.get_socket_addrs(&authority) {
                Ok(addrs) => {
                    let dns_result = DnsResult::new(addrs.clone(), Instant::now() + self.ttl);
                    // Re-acquire the cache lock only for the insertion.
                    self.manager
                        .map
                        .lock()
                        .unwrap()
                        .insert(authority, dns_result);
                    Ok(Box::new(addrs.into_iter()) as Addrs)
                }
                Err(err) => Err(Box::new(err) as StdError),
            }
        })
    }
}

#[cfg(feature = "tokio_base")]
#[cfg(test)]
mod ut_dns_cache {
    use super::*;

    /// UT test cases for `DefaultDnsResolver::resolve`.
    ///
    /// # Brief
    /// 1. Resolve the same authority twice with a 100 ms TTL. With Internet
    ///    access the first answer is cached; without it both calls fail.
    /// 2. Verify the second result equals the first one: the same addresses
    ///    on success, or the same error message on failure.
    #[tokio::test]
    async fn ut_default_dns_resolver() {
        let domain = "example.com:0";
        let resolver = DefaultDnsResolver::new(std::time::Duration::from_millis(100));
        let result1 = resolver.resolve(domain).await;
        let result2 = resolver.resolve(domain).await;
        // Compare the whole outcome. Comparing only the errors would make the
        // assertion vacuous (`None == None`) whenever both lookups succeed.
        let result1 = result1
            .map(|a| a.collect::<Vec<_>>())
            .map_err(|e| e.to_string());
        let result2 = result2
            .map(|a| a.collect::<Vec<_>>())
            .map_err(|e| e.to_string());
        assert_eq!(result1, result2);
    }
}

#[cfg(feature = "ylong_base")]
#[cfg(test)]
mod ut_dns_cache {
    use super::*;

    /// UT test cases for `DefaultDnsResolver::resolve`.
    ///
    /// # Brief
    /// 1. Resolve the same authority twice with a 100 ms TTL. With Internet
    ///    access the first answer is cached; without it both calls fail.
    /// 2. Verify the second result equals the first one: the same addresses
    ///    on success, or the same error message on failure.
    #[test]
    fn ut_default_dns_resolver() {
        ylong_runtime::block_on(ut_default_dns_resolver_async());
    }

    async fn ut_default_dns_resolver_async() {
        let domain = "example.com:0";
        let resolver = DefaultDnsResolver::new(std::time::Duration::from_millis(100));
        let result1 = resolver.resolve(domain).await;
        let result2 = resolver.resolve(domain).await;
        // Compare the whole outcome. Comparing only the errors would make the
        // assertion vacuous (`None == None`) whenever both lookups succeed.
        let result1 = result1
            .map(|a| a.collect::<Vec<_>>())
            .map_err(|e| e.to_string());
        let result2 = result2
            .map(|a| a.collect::<Vec<_>>())
            .map_err(|e| e.to_string());
        assert_eq!(result1, result2);
    }
}
